{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c5c63d791a48c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import math\n",
    "from typing import Optional, List\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from labml import tracker\n",
    "from labml_helpers.module import Module\n",
    "from labml_nn.utils import clone_module_list\n",
    "\n",
    "# Notebook-level toggle; presumably read by test cells (not referenced in the cells shown here)\n",
    "run_test = True\n",
    "\n",
    "# Compact tensor reprs: 3 decimals, wide lines so rows are not wrapped\n",
    "torch.set_printoptions(precision=3, linewidth=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0be617ff",
   "metadata": {},
   "source": [
    "MHA实现细节\n",
    "* Q和K的长度不一定一样，KV长度一定是一样的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24f751d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from os import name\n",
    "import re\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Multi-head scaled dot-product attention for sequence-first tensors.\n",
    "\n",
    "    Inputs are laid out as (seq_len, batch_size, d_model). Query and key may\n",
    "    have different sequence lengths; key and value must share theirs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            heads: number of attention heads.\n",
    "            d_model: model width; must be divisible by `heads`.\n",
    "            dropout_prob: dropout probability applied to the attention weights.\n",
    "            bias: whether the query/key/value projections have a bias term.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `d_model` is not divisible by `heads`.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Fail early: otherwise the head split in forward() dies with an opaque view() error\n",
    "        if d_model % heads != 0:\n",
    "            raise ValueError(f'd_model ({d_model}) must be divisible by heads ({heads})')\n",
    "        self.d_k = d_model // heads\n",
    "        self.heads = heads\n",
    "\n",
    "        self.q_proj = nn.Linear(d_model, d_model, bias=bias)\n",
    "        self.k_proj = nn.Linear(d_model, d_model, bias=bias)\n",
    "        self.v_proj = nn.Linear(d_model, d_model, bias=bias)\n",
    "        self.output = nn.Linear(d_model, d_model)\n",
    "\n",
    "        # Scores are (seq_len_q, seq_len_k, batch, heads): dim=1 normalises over the keys\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "\n",
    "        self.scale = 1 / math.sqrt(self.d_k)\n",
    "\n",
    "        # Attention weights of the most recent forward pass (detached), for inspection\n",
    "        self.attn = None\n",
    "\n",
    "    def get_score(self, query: torch.Tensor, key: torch.Tensor):\n",
    "        \"\"\"Dot product of every query position with every key position, per head.\n",
    "\n",
    "        Q shape: (seq_len_q, batch_size, num_heads, head_dim)\n",
    "        K shape: (seq_len_k, batch_size, num_heads, head_dim)\n",
    "        score shape: (seq_len_q, seq_len_k, batch_size, num_heads)\n",
    "        \"\"\"\n",
    "        return torch.einsum('qbhd,kbhd->qkbh', query, key)\n",
    "\n",
    "    def validate_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):\n",
    "        \"\"\"Check the mask against the query/key shapes and make it broadcastable to the scores.\n",
    "\n",
    "        Args:\n",
    "            mask (torch.Tensor): mask of shape (seq_len_q, seq_len_k) or\n",
    "                (seq_len_q, seq_len_k, batch_size); the first and third axes may be 1.\n",
    "            query_shape (List[int]): shape of the query tensor.\n",
    "            key_shape (List[int]): shape of the key tensor.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: the mask with trailing singleton axes appended until it has\n",
    "            len(query_shape) + 1 dimensions. The tensor passed in is left untouched.\n",
    "        \"\"\"\n",
    "        assert mask.dim() >= 2\n",
    "        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]\n",
    "        assert mask.shape[1] == key_shape[0]\n",
    "        if mask.dim() >= 3:\n",
    "            assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]\n",
    "        while mask.dim() < len(query_shape) + 1:\n",
    "            # Out-of-place on purpose: an in-place unsqueeze_ would reshape the caller's tensor\n",
    "            mask = mask.unsqueeze(-1)\n",
    "        return mask\n",
    "\n",
    "    def split_head_(self, x: torch.Tensor):\n",
    "        \"\"\"View the last axis (d_model) as (heads, d_k), keeping all leading axes.\"\"\"\n",
    "        head_shape = x.shape[:-1]\n",
    "        x = x.view(*head_shape, self.heads, self.d_k)\n",
    "        return x\n",
    "\n",
    "    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):\n",
    "        \"\"\"Attend from `query` to `key`/`value`.\n",
    "\n",
    "        Args:\n",
    "            query: (seq_len_q, batch_size, d_model)\n",
    "            key: (seq_len_k, batch_size, d_model)\n",
    "            value: (seq_len_k, batch_size, d_model)\n",
    "            mask: optional mask accepted by `validate_mask`; entries equal to 0 are\n",
    "                excluded from attention.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: (seq_len_q, batch_size, d_model)\n",
    "        \"\"\"\n",
    "        seq_len_q, batch_size, _ = query.shape\n",
    "\n",
    "        if mask is not None:\n",
    "            mask = self.validate_mask(mask, query.shape, key.shape)\n",
    "\n",
    "        query = self.split_head_(self.q_proj(query))\n",
    "        key = self.split_head_(self.k_proj(key))\n",
    "        value = self.split_head_(self.v_proj(value))\n",
    "\n",
    "        # (seq_len_q, seq_len_k, batch_size, num_heads)\n",
    "        scores = self.get_score(query, key) * self.scale\n",
    "        if mask is not None:\n",
    "            # Masked positions get -inf so that softmax assigns them zero weight\n",
    "            scores = scores.masked_fill(mask == 0, float('-inf'))\n",
    "\n",
    "        attn = self.softmax(scores)\n",
    "        attn = self.dropout(attn)\n",
    "\n",
    "        # Weighted sum of the values over the key axis\n",
    "        x = torch.einsum('qkbh,kbhd->qbhd', attn, value)\n",
    "        self.attn = attn.detach()\n",
    "        # Merge (heads, d_k) back into a single feature axis\n",
    "        x = x.reshape(seq_len_q, batch_size, -1)\n",
    "        return self.output(x)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Smoke test: causal self-attention over a random batch; prints the output shape\n",
    "    heads, d_model = 4, 64\n",
    "    seq_len_q, batch = 5, 2\n",
    "    mha = MultiHeadAttention(heads=heads, d_model=d_model)\n",
    "\n",
    "    # Three independent random inputs, drawn in q, k, v order\n",
    "    q, k, v = (torch.randn(seq_len_q, batch, d_model) for _ in range(3))\n",
    "\n",
    "    # Lower-triangular mask: query position i may only attend to key positions <= i\n",
    "    mask1 = torch.ones(seq_len_q, seq_len_q).tril()\n",
    "    print(\n",
    "        mha(query=q, key=k, value=v, mask=mask1).shape\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e920ecf2",
   "metadata": {},
   "source": [
    "位置编码Positional Encoding\n",
    "\n",
    "register_buffer将张量注册为模型的缓冲区，它不会被优化器更新，可以选择是否成为state_dict。这一方法显然也会注册一个成员变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcd0c6d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_positional_encoding(d_model: int, max_len: int = 5000) -> torch.Tensor:\n",
    "    \"\"\"Build the fixed sinusoidal positional-encoding table.\n",
    "\n",
    "    For position `pos` and even feature index `i`, column `i` holds\n",
    "    sin(pos / 10000^(i / d_model)) and column `i + 1` holds the matching cos.\n",
    "\n",
    "    Args:\n",
    "        d_model: model width. Odd widths are supported; the last column is then a sin column.\n",
    "        max_len: number of positions to pre-compute.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: tensor of shape (max_len, 1, d_model) with requires_grad=False.\n",
    "        The singleton middle axis broadcasts over the batch axis of\n",
    "        (seq_len, batch, d_model) inputs.\n",
    "    \"\"\"\n",
    "    position = torch.arange(max_len).unsqueeze(1)\n",
    "    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
    "                         (-math.log(10000.0) / d_model))\n",
    "    # One angle per (position, frequency) pair: (max_len, ceil(d_model / 2))\n",
    "    angles = position * div_term\n",
    "\n",
    "    pe = torch.zeros(max_len, d_model)\n",
    "    pe[:, 0::2] = torch.sin(angles)\n",
    "    # An odd d_model has one fewer odd column than even columns, so trim the\n",
    "    # angle table to fit instead of failing with a shape mismatch.\n",
    "    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])\n",
    "\n",
    "    return pe.unsqueeze(1).requires_grad_(False)\n",
    "\n",
    "\n",
    "class PositionalEncoding(nn.Module):\n",
    "    \"\"\"Add the pre-computed positional encodings to the input, then apply dropout.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "        # Registered as a non-persistent buffer, so it is left out of state_dict\n",
    "        table = get_positional_encoding(d_model, max_len)\n",
    "        self.register_buffer('positional_encodings', table, persistent=False)\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        # Only the first x.shape[0] positions are needed for this input\n",
    "        seq_len = x.shape[0]\n",
    "        pe = self.positional_encodings[:seq_len].detach().requires_grad_(False)\n",
    "        shifted = x + pe\n",
    "        return self.dropout(shifted)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccb801c3",
   "metadata": {},
   "source": [
    "Position-wise Feed-Forward Network (FFN)\n",
    "\n",
    "门控线性单元（GLU）被认为能够改善Transformer：它将门控线性层的输出与激活后的第一个线性层输出逐元素相乘，再把乘积作为第二个线性层（将 d_ff 投影回 d_model）的输入。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f2d773",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedForward(Module):\n",
    "    \"\"\"Position-wise feed-forward network with an optional multiplicative gate.\n",
    "\n",
    "    Computes linear_2(dropout(activation(linear_1(x)))). When `is_gated` is set,\n",
    "    the activated hidden state is multiplied element-wise by gate(x) before dropout.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, d_ff: int,\n",
    "                 dropout: float = 0.1,\n",
    "                 activation=nn.GELU(),\n",
    "                 is_gated: bool = False,\n",
    "                 bias1: bool = True,\n",
    "                 bias2: bool = True,\n",
    "                 bias_gate: bool = True):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            d_model: input and output width.\n",
    "            d_ff: hidden width.\n",
    "            dropout: dropout probability applied to the hidden state.\n",
    "            activation: callable applied to the output of the first linear layer.\n",
    "            is_gated: whether to create and use the gate projection.\n",
    "            bias1: bias flag of the first linear layer.\n",
    "            bias2: bias flag of the second linear layer.\n",
    "            bias_gate: bias flag of the gate projection.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.linear_1 = nn.Linear(d_model, d_ff, bias=bias1)\n",
    "        self.linear_2 = nn.Linear(d_ff, d_model, bias=bias2)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = activation\n",
    "        self.is_gated = is_gated\n",
    "        if is_gated:\n",
    "            # The gate layer only exists on gated instances\n",
    "            self.gate = nn.Linear(d_model, d_ff, bias=bias_gate)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        hidden = self.activation(self.linear_1(x))\n",
    "        if self.is_gated:\n",
    "            hidden = hidden * self.gate(x)\n",
    "        hidden = self.dropout(hidden)\n",
    "        return self.linear_2(hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7571fc5a",
   "metadata": {},
   "source": [
    "Encoder Layer & Decoder Layer\n",
    "\n",
    "* Embedding的输入必须是[0,n_vocab-1]的整数，它对应了一个词汇，一个词汇对应了一个可学习的嵌入向量\n",
    "* 在计算embedding时乘以`math.sqrt(self.d_model)`,可以让嵌入向量的方差与维度成正比，提高数值稳定性\n",
    "* 可学习PE可以提高灵活性，但是增加了参数量，且效果依赖于数据，适用于特化，复杂，大规模的模型，GPT系列就使用了可学习编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62343f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmbeddingWithPE(nn.Module):\n",
    "    \"\"\"Token embedding scaled by sqrt(d_model) plus fixed positional encodings.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            d_model: embedding width.\n",
    "            n_vocab: vocabulary size of the embedding table.\n",
    "            max_len: number of positions to pre-compute encodings for.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.linear = nn.Embedding(n_vocab, d_model)\n",
    "        self.d_model = d_model\n",
    "        # The buffer name must match the attribute read in forward(); the old\n",
    "        # singular spelling made every forward() call raise AttributeError.\n",
    "        self.register_buffer(\"positional_encodings\",\n",
    "                             get_positional_encoding(d_model, max_len))\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        \"\"\"Embed token ids and add the encodings of the first x.shape[0] positions.\n",
    "\n",
    "        Assumes `x` holds integer token ids with the sequence axis first.\n",
    "        \"\"\"\n",
    "        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)\n",
    "        return self.linear(x) * math.sqrt(self.d_model) + pe\n",
    "\n",
    "\n",
    "class EmbeddingWithLearnedPE(nn.Module):\n",
    "    \"\"\"Token embedding scaled by sqrt(d_model) plus trainable positional encodings.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Embedding(n_vocab, d_model)\n",
    "        self.d_model = d_model\n",
    "        # Zero-initialised trainable table of shape (max_len, 1, d_model)\n",
    "        initial_pe = torch.zeros(max_len, 1, d_model)\n",
    "        self.positional_encodings = nn.Parameter(initial_pe, requires_grad=True)\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        seq_len = x.shape[0]\n",
    "        scaled_tokens = self.linear(x) * math.sqrt(self.d_model)\n",
    "        return scaled_tokens + self.positional_encodings[:seq_len]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a4684e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerLayer(nn.Module):\n",
    "    \"\"\"Pre-norm transformer block: self-attention, optional source attention, feed-forward.\n",
    "\n",
    "    Every sub-layer is applied as x + dropout(sublayer(norm(x))).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *,\n",
    "                 d_model: int,\n",
    "                 self_attn: MultiHeadAttention,\n",
    "                 src_attn: MultiHeadAttention = None,\n",
    "                 feed_forward: FeedForward,\n",
    "                 dropout_prob: float):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            d_model: feature width normalised by the RMSNorm layers.\n",
    "            self_attn: attention module used for self-attention.\n",
    "            src_attn: optional attention module used to attend to `src`;\n",
    "                leave as None for layers that never receive a source.\n",
    "            feed_forward: position-wise feed-forward module.\n",
    "            dropout_prob: dropout probability applied to each sub-layer output\n",
    "                before the residual addition.\n",
    "\n",
    "        Note:\n",
    "            The attention modules used here apply no positional encoding; replace\n",
    "            or override them if position information must be injected at this level.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.size = d_model\n",
    "        self.self_attn = self_attn\n",
    "        self.src_attn = src_attn\n",
    "        self.feed_forward = feed_forward\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "        self.norm_self_attn = nn.RMSNorm([d_model])\n",
    "        if self.src_attn is not None:\n",
    "            self.norm_src_attn = nn.RMSNorm([d_model])\n",
    "        self.norm_ff = nn.RMSNorm([d_model])\n",
    "        # When True, forward() keeps a copy of the feed-forward input in self.ff_input\n",
    "        self.is_save_ff_input = False\n",
    "\n",
    "    def forward(self, *, x: torch.Tensor, mask: torch.Tensor,\n",
    "                src: torch.Tensor = None,\n",
    "                src_mask: torch.Tensor = None):\n",
    "        \"\"\"Run the block on `x`.\n",
    "\n",
    "        Args:\n",
    "            x (torch.Tensor): input to the layer (the residual stream).\n",
    "            mask (torch.Tensor): mask passed to the self-attention module.\n",
    "            src (torch.Tensor): optional source to attend to, e.g. encoder outputs\n",
    "                when this is a decoder layer. The source-attention step is skipped when None.\n",
    "            src_mask (torch.Tensor): mask passed to the source-attention module.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: the updated residual stream.\n",
    "        \"\"\"\n",
    "\n",
    "        # Self-attention sub-layer\n",
    "        z = self.norm_self_attn(x)\n",
    "        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)\n",
    "        x = x + self.dropout(self_attn)\n",
    "\n",
    "        # If a source is provided, attend to it. This is the decoder case, where\n",
    "        # the layer pays attention to encoder outputs.\n",
    "        if src is not None:\n",
    "            # Normalise the *updated* residual x; normalising the stale z again would\n",
    "            # make the query ignore the self-attention result.\n",
    "            z = self.norm_src_attn(x)\n",
    "            src_attn = self.src_attn(\n",
    "                query=z, key=src, value=src, mask=src_mask)\n",
    "            x = x + self.dropout(src_attn)\n",
    "\n",
    "        # Feed-forward sub-layer\n",
    "        z = self.norm_ff(x)\n",
    "        if self.is_save_ff_input:\n",
    "            self.ff_input = z.clone()\n",
    "        ff = self.feed_forward(z)\n",
    "        x = x + self.dropout(ff)\n",
    "\n",
    "        return x\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d9a3e82",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Stack of transformer layers followed by a final RMSNorm.\"\"\"\n",
    "\n",
    "    def __init__(self, layer: TransformerLayer, n_layers: int):\n",
    "        super().__init__()\n",
    "        # clone_module_list (labml_nn.utils) is handed the prototype layer and a count;\n",
    "        # presumably it returns n_layers independent copies - confirm against labml_nn\n",
    "        self.layers = clone_module_list(layer, n_layers)\n",
    "        self.norm = nn.RMSNorm([layer.size])\n",
    "\n",
    "    def forward(self, x: torch.Tensor, mask: torch.Tensor):\n",
    "        # Thread the activations through every layer in order, with the same mask\n",
    "        hidden = x\n",
    "        for block in self.layers:\n",
    "            hidden = block(x=hidden, mask=mask)\n",
    "        return self.norm(hidden)\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Stack of transformer layers that also attend to `memory`, followed by a final RMSNorm.\"\"\"\n",
    "\n",
    "    def __init__(self, layer: TransformerLayer, n_layers: int):\n",
    "        super().__init__()\n",
    "        # clone_module_list (labml_nn.utils) is handed the prototype layer and a count;\n",
    "        # presumably it returns n_layers independent copies - confirm against labml_nn\n",
    "        self.layers = clone_module_list(layer, n_layers)\n",
    "        self.norm = nn.RMSNorm([layer.size])\n",
    "\n",
    "    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):\n",
    "        # tgt_mask goes to self-attention, src_mask to the attention over memory\n",
    "        hidden = x\n",
    "        for block in self.layers:\n",
    "            hidden = block(x=hidden, mask=tgt_mask, src=memory, src_mask=src_mask)\n",
    "        return self.norm(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cd9a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Linear read-out from d_model features to n_vocab scores per token.\"\"\"\n",
    "\n",
    "    def __init__(self, n_vocab: int, d_model: int):\n",
    "        \"\"\"Create the d_model -> n_vocab projection used to turn token embeddings into tokens.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.projection = nn.Linear(in_features=d_model, out_features=n_vocab)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.projection(x)\n",
    "        return logits\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c85addbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderDecoder(nn.Module):\n",
    "    \"\"\"Bundle of source/target embeddings, encoder, decoder and generator.\n",
    "\n",
    "    forward() returns the decoder output; the generator is stored but not applied here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder: Encoder,  decoder: Decoder,\n",
    "                 src_embed: nn.Module, tgt_embed: nn.Module,\n",
    "                 generator: nn.Module):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.src_embed = src_embed\n",
    "        self.tgt_embed = tgt_embed\n",
    "        self.generator = generator\n",
    "\n",
    "        # Xavier-uniform initialisation for every parameter with more than one dimension\n",
    "        for weight in filter(lambda p: p.dim() > 1, self.parameters()):\n",
    "            nn.init.xavier_uniform_(weight)\n",
    "\n",
    "    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):\n",
    "        \"\"\"Embed the source and run it through the encoder.\"\"\"\n",
    "        embedded_src = self.src_embed(src)\n",
    "        return self.encoder(embedded_src, src_mask)\n",
    "\n",
    "    def decode(self,  memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):\n",
    "        \"\"\"Embed the target and run the decoder against the encoder output `memory`.\"\"\"\n",
    "        embedded_tgt = self.tgt_embed(tgt)\n",
    "        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)\n",
    "\n",
    "    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):\n",
    "        \"\"\"Encode `src`, then decode `tgt` against the result.\"\"\"\n",
    "        memory = self.encode(src, src_mask)\n",
    "        return self.decode(memory, src_mask, tgt, tgt_mask)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
